import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
import time

# Seed NumPy's global random generator at import time so that every run of
# this script draws the same sequence of random numbers.
np.random.seed(42)

class RepresentationDynamics:
    """Forward-Euler simulation of a damped, noise-driven representation vector.

    The state consists of a representation vector ``Rep`` and a scalar energy
    ``E``.  Every step draws Gaussian noise from NumPy's global generator,
    advances ``Rep`` by ``rep_dynamics(...) * dt``, advances ``E`` by
    ``energy_dynamics(...) * dt`` (clamped from below at 0.01) and appends
    diagnostics to ``history``.

    Attributes:
        history: dict with keys ``'t'``, ``'Rep'``, ``'E'``, ``'V'``,
            ``'gamma'`` and ``'stability'``.  Values are lists while stepping;
            ``run_simulation`` converts every entry except ``'t'`` to a NumPy
            array once the run has finished.
    """

    # Order matters only for readability; it mirrors the quantities recorded per step.
    _HISTORY_KEYS = ('t', 'Rep', 'E', 'V', 'gamma', 'stability')

    def __init__(self, dim=3, gamma0=0.1, rho=0.3, kappa=0.2, delta=0.05,
                 E0=1.0, noise_strength=0.1, dt=0.01, T=20.0):
        """Store the model parameters and draw a random initial representation.

        Args:
            dim: Dimension of the representation vector.
            gamma0: Baseline damping coefficient.
            rho: Weight of the energy-dependent damping term.
            kappa: Energy depletion rate (also the decay rate inside the
                energy-dependent damping term).
            delta: Weight of the ``||Rep||``-proportional damping term.
            E0: Initial energy.
            noise_strength: Standard deviation of the per-step Gaussian noise.
            dt: Euler time step.
            T: Total simulated time.
        """
        self.dim = dim  # Dimension of representation space
        self.gamma0 = gamma0  # Baseline damping coefficient
        self.rho = rho  # Sensitivity to metabolic availability
        self.kappa = kappa  # Energy depletion rate
        self.delta = delta  # Self-regulatory damping coefficient
        self.E0 = E0  # Initial energy
        self.noise_strength = noise_strength  # Standard deviation of the sampled noise
        self.dt = dt  # Time step for simulation
        self.T = T  # Total simulation time
        self.L = 0.5  # Scale of phi; since tanh is 1-Lipschitz, phi is L-Lipschitz
        # Initialize state variables
        self.Rep = np.random.normal(0, 0.1, size=dim)  # Initial representation
        self.E = E0  # Current energy level
        # Per-step diagnostics
        self.history = self._empty_history()

    def _empty_history(self):
        """Return a fresh history dict with one empty list per tracked quantity."""
        return {key: [] for key in self._HISTORY_KEYS}

    def _record(self, **values):
        """Append one value per history key, tolerating array-valued entries.

        After ``run_simulation`` the history entries are NumPy arrays, which
        have no ``append``; they are turned back into lists on demand so that
        further manual stepping keeps working.
        """
        for key, value in values.items():
            series = self.history[key]
            if not isinstance(series, list):
                series = list(series)
                self.history[key] = series
            series.append(value)

    def potential_gradient(self, Rep):
        """Gradient of the potential function V(Rep) = 0.5 * ||Rep||^2.

        Note: the argument itself is returned (no copy), so the result aliases
        ``Rep``.
        """
        return Rep

    def phi(self, Rep):
        """Coupling function between noise and representation: ``L * tanh(Rep)``."""
        return self.L * np.tanh(Rep)

    def _noise_coupling(self, Rep, noise):
        """Noise contribution to dRep/dt: ``2 * phi(Rep) * noise`` (element-wise)."""
        return 2 * self.phi(Rep) * noise

    def gamma(self, E, Rep):
        """Active damping: baseline + energy-dependent term + ``delta * ||Rep||``."""
        return self.gamma0 + self.rho * np.exp(-self.kappa * E) + self.delta * np.linalg.norm(Rep)

    def energy_dynamics(self, E, Rep):
        """Return dE/dt; energy decays faster when ``||Rep||`` is large."""
        return -self.kappa * E * (1 + 0.1 * np.linalg.norm(Rep)**2)

    def rep_dynamics(self, t, Rep, E, noise):
        """Return dRep/dt = -grad V(Rep) - gamma(E, Rep) * Rep + noise coupling.

        ``t`` is accepted for interface compatibility but is not used.
        """
        grad_V = self.potential_gradient(Rep)
        gamma_t = self.gamma(E, Rep)
        return -grad_V - gamma_t * Rep + self._noise_coupling(Rep, noise)

    def lyapunov(self, Rep, grad_V, gamma, noise_term):
        """Calculate dV/dt for the Lyapunov function V = 0.5 * ||Rep||^2.

        Args:
            Rep: State at which the derivative is evaluated.
            grad_V: Potential gradient at ``Rep``.
            gamma: Damping coefficient at ``Rep``.
            noise_term: Noise contribution to dRep/dt at ``Rep``.
        """
        V_dot = -np.dot(Rep, grad_V) - gamma * np.linalg.norm(Rep)**2 + np.dot(Rep, noise_term)
        return V_dot

    def stability_measure(self, Rep, E, noise):
        """Calculate the stability measure ``alpha - beta`` at ``(Rep, E)``.

        ``alpha = gamma(E, Rep) - ||grad V(Rep)||`` and
        ``beta = L * noise_strength * sqrt(E)``.  ``noise`` is accepted for
        interface compatibility but is not used.
        """
        grad_V = self.potential_gradient(Rep)
        gamma_t = self.gamma(E, Rep)
        alpha = gamma_t - np.linalg.norm(grad_V)
        # Use the energy that was passed in, so the measure can be evaluated
        # at states other than the simulator's current one.
        beta = self.L * self.noise_strength * np.sqrt(E)
        return alpha - beta  # Simplified, no division by ||Rep||

    def simulate_step(self, t, print_interval=1.0):
        """Advance the state by one Euler step and record diagnostics.

        Recorded per step: ``t``, the updated ``Rep`` (copy), the updated
        ``E``, ``V`` and the stability measure at the updated state, and the
        damping ``gamma`` that was used to compute this step (i.e. evaluated
        at the state before the update).

        Args:
            t: Time label stored in the history and used to decide when to print.
            print_interval: A status line is printed whenever ``t`` crosses a
                multiple of this interval.
        """
        # Generate noise for this step
        noise = np.random.normal(0, self.noise_strength, size=self.dim)

        # Snapshot the state at time t.  potential_gradient returns its
        # argument itself, so without a copy the in-place update below would
        # silently rewrite the "pre-update" gradient as well.
        rep_prev = self.Rep.copy()
        gamma_t = self.gamma(self.E, rep_prev)
        dRep = self.rep_dynamics(t, rep_prev, self.E, noise)
        # dV/dt evaluated consistently at the pre-update state, with the same
        # noise coupling that drives rep_dynamics (equals dot(rep_prev, dRep)).
        V_dot = self.lyapunov(
            rep_prev,
            self.potential_gradient(rep_prev),
            gamma_t,
            self._noise_coupling(rep_prev, noise),
        )

        # Update representation (in place, so external references stay valid)
        self.Rep += dRep * self.dt
        # Update energy using the already-updated representation
        dE = self.energy_dynamics(self.E, self.Rep)
        self.E += dE * self.dt
        self.E = max(0.01, self.E)  # Ensure energy stays positive

        # Diagnostics at the updated state
        V = 0.5 * np.linalg.norm(self.Rep)**2
        stability = self.stability_measure(self.Rep, self.E, noise)

        self._record(t=t, Rep=self.Rep.copy(), E=self.E, V=V,
                     gamma=gamma_t, stability=stability)

        # Print results whenever t has crossed a multiple of print_interval
        if int(t / print_interval) != int((t - self.dt) / print_interval):
            rep_norm = np.linalg.norm(self.Rep)
            print(f"t={t:.2f}, ||Rep||={rep_norm:.4f}, E={self.E:.4f}, gamma={gamma_t:.4f}, "
                  f"V={V:.4f}, dV/dt={V_dot:.4f}, stability={stability:.4f}")
            if stability > 0:
                print("✓ Stability condition satisfied: system is converging")
            else:
                print("✗ Stability condition not satisfied: system may be unstable")

    def run_simulation(self):
        """Run the full simulation from the current state and analyze it.

        Recording restarts on every call (``history`` is cleared), while the
        state ``Rep``/``E`` continues from wherever it currently is.  After
        the run, every history entry except ``'t'`` is a NumPy array.

        Raises:
            ValueError: If ``dt`` is not positive (the run could not terminate).
        """
        if self.dt <= 0:
            raise ValueError(f"dt must be positive, got {self.dt!r}")

        # Start a fresh recording so the method can be called more than once.
        self.history = self._empty_history()

        print("\n=== Stability of Representation Dynamics (lemma I.B) ===")
        print("System parameters:")
        print(f"- Dimension: {self.dim}")
        print(f"- Baseline damping (gamma0): {self.gamma0}")
        print(f"- Metabolic sensitivity (rho): {self.rho}")
        print(f"- Energy depletion rate (kappa): {self.kappa}")
        print(f"- Self-regulatory damping (delta): {self.delta}")
        print(f"- Initial energy (E0): {self.E0}")
        print(f"- Noise strength: {self.noise_strength}")
        print(f"- Lipschitz constant (L): {self.L}")

        print("\nRunning simulation...")
        print("Time    ||Rep||    Energy    Damping    Lyapunov    dV/dt    Stability")
        print("--------------------------------------------------------------------------------")
        print_interval = 1.0
        # Derive the step count once and compute t = step * dt instead of
        # accumulating t += dt: repeated float addition drifts and can add or
        # drop a step.  Rounding first absorbs ratios such as 1499.9999999998.
        n_steps = max(0, int(np.ceil(round(self.T / self.dt, 9))))
        start_time = time.time()
        for step in range(n_steps):
            self.simulate_step(step * self.dt, print_interval)
        runtime = time.time() - start_time

        # Convert lists to numpy arrays for analysis ('t' stays a list)
        for key in ('Rep', 'E', 'V', 'gamma', 'stability'):
            self.history[key] = np.array(self.history[key])

        print("\n=== Simulation Complete ===")
        print(f"Runtime: {runtime:.2f} seconds")
        self.analyze_results()

    def analyze_results(self):
        """Print summary statistics and consistency checks for the recorded run.

        Compares the largest recorded ``||Rep||`` with the bound
        ``||Rep_0|| * exp(-gamma0 * T) + beta / gamma0`` (10 % margin), reports
        the mean/min stability measure, and reports the correlation between
        recorded energy and damping.  Returns ``None``.
        """
        if len(self.history['Rep']) == 0:
            # Nothing was recorded (e.g. T <= 0); indexing below would fail.
            print("\n=== Results Analysis ===")
            print("No simulation steps were recorded; nothing to analyze.")
            return

        # Calculate Rep norms
        rep_norms = np.array([np.linalg.norm(rep) for rep in self.history['Rep']])
        initial_norm = rep_norms[0]
        final_norm = rep_norms[-1]
        max_norm = np.max(rep_norms)
        # Calculate average stability measure
        stability_measures = np.array(self.history['stability'])
        avg_stability = np.mean(stability_measures)
        min_stability = np.min(stability_measures)
        # Calculate theoretical bound from lemma I.B
        beta = self.L * self.noise_strength * np.sqrt(self.E0)
        epsilon = self.gamma0  # Used as the decay rate in the bound
        theoretical_bound = initial_norm * np.exp(-epsilon * self.T) + beta / epsilon

        print("\n=== Results Analysis ===")
        print(f"Initial ||Rep||: {initial_norm:.4f}")
        print(f"Final ||Rep||: {final_norm:.4f}")
        print(f"Maximum ||Rep||: {max_norm:.4f}")
        print(f"Theoretical bound (from lemma I.B): {theoretical_bound:.4f}")
        print(f"\nAverage stability measure: {avg_stability:.4f}")
        print(f"Minimum stability measure: {min_stability:.4f}")

        # Test for lemma I.B consistency
        within_bound = max_norm <= theoretical_bound * 1.1  # Allow 10% margin for numerical errors
        if within_bound:
            print("\n✓ CONFIRMED: Representations remained within theoretical bounds (lemma I.B)")
        else:
            print("\n✗ INCONSISTENT: Representations exceeded theoretical bounds")
        if min_stability > 0:
            print("✓ CONFIRMED: System maintained stability throughout simulation")
        else:
            stable_percentage = np.mean(stability_measures > 0) * 100
            print(f"✗ PARTIAL STABILITY: System was stable {stable_percentage:.1f}% of the time")

        # Test the role of metabolic damping
        gamma_values = np.array(self.history['gamma'])
        energy_values = np.array(self.history['E'])
        energy_gamma_corr = np.corrcoef(energy_values, gamma_values)[0, 1]
        print(f"\nCorrelation between energy and damping: {energy_gamma_corr:.4f}")
        if energy_gamma_corr < 0:
            print("✓ CONFIRMED: Metabolic depletion increased damping as expected")
        else:
            print("✗ UNEXPECTED: Metabolic depletion did not increase damping as expected")

        print("\n=== Summary ===")
        if avg_stability > 0 and within_bound:
            print("The simulation CONFIRMS the stability properties described in lemma I.B.")
            print("Representations remained bounded under metabolic constraints and noise perturbations.")
        else:
            print("The simulation shows PARTIAL SUPPORT for lemma I.B.")
            print("Some aspects of the theoretical predictions were not fully observed.")


def test_lemma2_prediction(title, description, expected_outcome, test_function):
    """Utility function to test specific aspects of lemma I.B"""
    print(f"\n=== Testing: {title} ===")
    print(f"Description: {description}")
    print(f"Expected outcome: {expected_outcome}")
    result = test_function()
    if result:
        print(f"✓ TEST PASSED: The simulation confirms this aspect of lemma I.B")
    else:
        print(f"✗ TEST FAILED: The simulation does not support this aspect of lemma I.B")
    return result


def test_bounded_representations():
    """Check that the largest recorded ``||Rep||`` stays under the analytic bound.

    Runs a simulation with ``noise_strength=0.3`` and ``T=15.0`` and compares
    the peak norm with ``||Rep_0|| * exp(-gamma0 * T) + beta / gamma0``, where
    ``beta = L * noise_strength * sqrt(E0)`` and ``Rep_0`` is the first
    recorded state.  A 20 % margin is allowed.  Returns the comparison result.
    """
    print("\nRunning test for bounded representations...")
    sim = RepresentationDynamics(noise_strength=0.3, T=15.0)
    sim.run_simulation()

    # One Euclidean norm per recorded state.
    norms = list(map(np.linalg.norm, sim.history['Rep']))

    noise_floor = sim.L * sim.noise_strength * np.sqrt(sim.E0)
    decay_rate = sim.gamma0
    transient = norms[0] * np.exp(-decay_rate * sim.T)
    theoretical_bound = transient + noise_floor / decay_rate

    peak = max(norms)
    print(f"Maximum ||Rep||: {peak:.4f}, Theoretical bound: {theoretical_bound:.4f}")

    slack = 1.2  # Allow 20% margin for stochastic effects
    return peak <= theoretical_bound * slack


def test_metabolic_damping():
    """Check that damping is higher, and energy lower, late in a run than early.

    Runs a simulation with ``kappa=0.5, rho=0.8, E0=2.0, T=15.0`` and compares
    the mean of the leading window of each recorded series with the mean of
    its trailing window.  Returns the combined comparison result.
    """
    print("\nRunning test for metabolic damping effect...")
    sim = RepresentationDynamics(kappa=0.5, rho=0.8, E0=2.0, T=15.0)  # Faster energy depletion
    sim.run_simulation()

    energy_trace = np.array(sim.history['E'])
    damping_trace = np.array(sim.history['gamma'])

    def window_means(trace):
        # Leading window: first len // 4 samples.  Trailing window: slice from
        # index (-len) // 4, which floors, so it can hold one sample more than
        # the leading window when len is not a multiple of 4.
        head = trace[:len(trace) // 4]
        tail = trace[-len(trace) // 4:]
        return np.mean(head), np.mean(tail)

    early_energy, late_energy = window_means(energy_trace)
    early_gamma, late_gamma = window_means(damping_trace)

    print(f"Early energy: {early_energy:.4f}, Early damping: {early_gamma:.4f}")
    print(f"Late energy: {late_energy:.4f}, Late damping: {late_gamma:.4f}")

    damping_rose = late_gamma > early_gamma
    return damping_rose and late_energy < early_energy


def test_lyapunov_function():
    """Check that ``V = 0.5 * ||Rep||^2`` is small and steady at the end of a run.

    Runs a simulation with ``T=10.0``, recomputes ``V`` from the recorded
    states, and inspects the trailing window (slice from index
    ``(-len) // 5``).  Passes when that window's standard deviation is below
    0.001 and its mean is below 0.01.
    """
    print("\nRunning test for Lyapunov function behavior...")
    sim = RepresentationDynamics(T=10.0)
    sim.run_simulation()

    lyapunov_trace = [0.5 * np.linalg.norm(state)**2 for state in sim.history['Rep']]
    tail = lyapunov_trace[-len(lyapunov_trace) // 5:]
    tail_mean = np.mean(tail)
    tail_std = np.std(tail)

    print(f"Final V(t) mean: {tail_mean:.4f}, std: {tail_std:.4f}")
    print(f"Coefficient of variation: {tail_std / tail_mean:.4f}")

    steady = tail_std < 0.001
    return steady and tail_mean < 0.01  # Absolute bounds


def main():
    """Run the default simulation, three parameter scenarios and the three checks.

    Prints a banner for every stage, runs each simulation via
    ``RepresentationDynamics(...).run_simulation()``, executes the checks
    through ``test_lemma2_prediction`` and finally prints an overall verdict
    based on how many checks returned a truthy result.
    """
    banner = "=" * 80

    def section(heading):
        # Three-line banner: blank line + rule, heading, rule.
        print("\n" + banner)
        print(heading)
        print(banner)

    print("\n=== Running Default Simulation ===")
    RepresentationDynamics().run_simulation()

    section("TESTING DIFFERENT METABOLIC SCENARIOS")

    # (heading, constructor overrides) for each scenario, in execution order.
    scenarios = (
        # Increased sensitivity to metabolic availability
        ("\n=== Scenario 1: Higher Metabolic Sensitivity ===",
         {"rho": 0.6}),
        ("\n=== Scenario 2: Increased Noise with Higher Self-Regulation ===",
         {"noise_strength": 0.2, "delta": 0.1}),
        # Low energy, higher baseline damping
        ("\n=== Scenario 3: Low Initial Energy ===",
         {"E0": 0.5, "gamma0": 0.15}),
    )
    for heading, overrides in scenarios:
        print(heading)
        RepresentationDynamics(**overrides).run_simulation()

    section("TESTING THE THEORETICAL PREDICTIONS OF lemma I.B")

    # (title, description, expected outcome, check function) per check.
    checks = (
        ("Bounded Representations",
         "lemma I.B predicts that representations remain bounded under noise perturbations",
         "||Rep(t)|| should stay below the theoretical bound",
         test_bounded_representations),
        ("Metabolic Damping Effect",
         "As energy depletes, the system should increase damping to maintain stability",
         "Damping (gamma) should increase as energy decreases",
         test_metabolic_damping),
        ("Lyapunov Function Behavior",
         "The Lyapunov function V(t) = 0.5 * ||Rep||^2 should stabilize",
         "V(t) should converge to a bounded region with small fluctuations",
         test_lyapunov_function),
    )
    test_results = [test_lemma2_prediction(*check) for check in checks]

    section("FINAL SUMMARY OF lemma I.B VALIDATION")

    passed = sum(test_results)
    total = len(test_results)
    print(f"\nTests passed: {passed}/{total}")

    if all(test_results):
        verdict_lines = (
            "\n✓ STRONG VALIDATION: All tests passed, confirming the stability properties in lemma I.B.",
            "Representation dynamics remain stable under metabolic constraints and perturbations.",
            "The simulation confirms that:",
            "1. Representations remain bounded under noise perturbations",
            "2. Metabolic depletion triggers increased damping",
            "3. The Lyapunov function stabilizes, indicating convergence to a bounded region",
        )
    elif passed >= total / 2:
        verdict_lines = (
            "\n△ PARTIAL VALIDATION: Some tests passed, providing partial support for lemma I.B.",
            "The simulation shows that some aspects of the theory hold, but with limitations.",
        )
    else:
        verdict_lines = (
            "\n✗ WEAK VALIDATION: Few tests passed, suggesting potential issues with the predictions of lemma I.B.",
            "The simulation indicates that the theoretical predictions may need refinement.",
        )
    for line in verdict_lines:
        print(line)

    print("\nThis concludes the validation of lemma I.B: Stability of Representation Dynamics")


# Run the full demonstration only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()